import pandas as pd


def analyze_dataset(file_path, dataset_name, label_col="Attrition"):
    """
    Report the positive/negative sample ratio of a binary-labelled dataset.

    Reads a CSV, counts rows whose label column equals 1 (positive) and 0
    (negative), prints a short summary, and returns the counts.

    Args:
        file_path (str): Path to the CSV file. Anything ``pd.read_csv``
            accepts (e.g. a file-like object) also works.
        dataset_name (str): Dataset name used in the printed header.
        label_col (str): Name of the label column. Defaults to
            ``"Attrition"``.

    Returns:
        dict: ``total``, ``positive``, ``negative``, ``other`` (rows whose
        label is neither 1 nor 0, e.g. missing values), plus
        ``positive_ratio`` / ``negative_ratio`` as percentages of ``total``.

    Raises:
        KeyError: If ``label_col`` is not a column of the CSV.
    """
    df = pd.read_csv(file_path)
    labels = df[label_col]

    # Convert numpy scalars to plain ints so the returned dict is easy to use.
    positive_count = int((labels == 1).sum())
    negative_count = int((labels == 0).sum())
    total_count = len(df)
    # Rows whose label is neither 1 nor 0 (NaN, unexpected values, ...). They
    # still count towards the total, which is why the two ratios may not sum
    # to 100%.
    other_count = total_count - positive_count - negative_count

    # Guard against a header-only CSV: dividing by a zero total would
    # otherwise produce NaN ratios.
    if total_count:
        positive_ratio = positive_count / total_count * 100
        negative_ratio = negative_count / total_count * 100
    else:
        positive_ratio = negative_ratio = 0.0

    print(f"{dataset_name} 数据集:")
    print(f"  总样本数: {total_count}")
    print(f"  正例 ({label_col}=1): {positive_count} ({positive_ratio:.2f}%)")
    print(f"  负例 ({label_col}=0): {negative_count} ({negative_ratio:.2f}%)")
    if other_count:
        print(f"  其他取值: {other_count}")
    print("-" * 40)

    return {
        "total": total_count,
        "positive": positive_count,
        "negative": negative_count,
        "other": other_count,
        "positive_ratio": positive_ratio,
        "negative_ratio": negative_ratio,
    }


if __name__ == "__main__":
    # Summarise the train split first, then the test split. Each split has its
    # own try/except, so a failure on one (missing file, missing column, ...)
    # is reported without stopping the other from being analysed.
    # NOTE(review): these paths resolve against the current working directory,
    # not this script's location — confirm the script is always run from the
    # intended directory.
    splits = (
        ("../data/train.csv", "训练", "train.csv"),
        ("../data/test.csv", "测试", "test.csv"),
    )
    for csv_path, split_label, short_name in splits:
        try:
            analyze_dataset(csv_path, split_label)
        except Exception as err:
            print(f"无法读取 {short_name}: {err}")
